import torch  # torch must be imported first
import ops_test_op


# PyTorch wrapper for the op, providing forward and backward passes
class OpsTest(torch.autograd.Function):
    """Autograd wrapper around the ``ops_test_op.test`` extension call.

    ``forward`` hands both inputs and a freshly allocated buffer to the
    extension; ``backward`` returns a copy of the incoming tensor as the
    gradient for each of the two inputs.
    """

    @staticmethod
    def forward(ctx, input1: torch.Tensor, input2: torch.Tensor):
        """Call ``ops_test_op.test`` and return the buffer passed to it.

        Args:
            ctx: Autograd context object; not used here.
            input1: First operand. The returned tensor is allocated with
                ``torch.zeros_like(input1)``.
            input2: Second operand, forwarded to the extension unchanged.

        Returns:
            The tensor that was passed to ``ops_test_op.test`` as its
            third argument.
        """
        # Zero-filled buffer mirroring input1.
        # NOTE(review): presumably the extension writes its result into
        # this buffer in place -- confirm against the op's implementation.
        result = torch.zeros_like(input1)
        ops_test_op.test(input1, input2, result)
        return result

    @staticmethod
    def backward(ctx, output):
        """Return one copy of ``output`` per forward input.

        Args:
            ctx: Autograd context object; not used here.
            output: Tensor supplied by autograd for the backward pass
                (per the ``torch.autograd.Function`` contract this is the
                gradient with respect to ``forward``'s return value).

        Returns:
            A 2-tuple ``(grad_input1, grad_input2)``; each element is a
            separate clone of ``output``.
        """
        # NOTE(review): passing the gradient through unchanged to both
        # inputs is only valid for some ops -- verify it matches what
        # ops_test_op.test computes.
        return torch.clone(output), torch.clone(output)
